import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from collections import defaultdict
from itertools import product

import utils
from utils import LegislatorsDataset, make_rank_series

from utils import permute_data, ConditionalMajorityVote, make_indicator, \
                  party_indicator, load_csv, to_csv, predict_and_score

from utils import get_triu
from pingouin import partial_corr

#seed the global NumPy RNG: the row shuffle and the permutation test below 
#	both draw from it
np.random.seed(35897150)

#analysis switches
include_party_var, sep_party_vars = True, False
remove_present = False	#True drops 'Present' votes before they are coded
only_bills, only_house = True, False	#only_bills keeps passage-type roll calls

#congress ranges: congresses_modeled names the input CSV files, 
#	congresses is the set that is actually analysed
congresses_modeled = list(range(110, 118))
congresses = congresses_modeled[1:]
congress = 114	#placeholder: rebound by the loops over congresses below
from scipy.stats import pearsonr

from sklearn.metrics import pairwise_distances
#distance metrics handed to sklearn's pairwise_distances
votes_metric = 'cityblock'	#between legislators' vote vectors
financials_metric = 'cosine'	#between legislators' financial embeddings

merge_third_parties = True	#collapse every non-D/R party label to 'I'
name_field = 'Name_merged'	#column / Legislators attribute identifying a member

n_components = 25
n_components_bill_embed = 25
cv_folds = 10
analysis_dir = '.'	#inputs are read from here, results are written below it

embed_field_name = 'PCAEmbed'	#Legislators attribute holding the embedding
corr_method = 'spearman'	#passed as `method` to the pingouin correlations
recode_independents = False	#True recodes a few named independents to D/R
run_permutation = False	#True runs the (slow) permutation test at the end

#only PCA embeddings are truncated to their leading components; 
#	None makes the [:, :n_components] slices below keep every column
if not embed_field_name.startswith('PCA'):
  n_components = None

results_dirname = 'results-RSA'

#makedirs creates analysis_dir and the results folder in one call, also when 
#	analysis_dir is a nested path, and is a no-op when both already exist
os.makedirs(os.path.join(analysis_dir, results_dirname), exist_ok=True)

#columns read from the per-member vote table
usecols = [ 'congress', 'chamber', 'rollnumber', 'Name', 'Name_merged', 
            'bill_id', 'crs_policy_area', 'vote_result', 'vote_desc', 
            'yea_count', 'nay_count', 'vote_question', 'sponsor', 
            'sponsor_party', 'sponsor_state', 'primary_subject', 
            'short_title', 'summary_short', 'party', 'cast_description' ]

print('loading in bill models and vote data')
#reuse df_votes_ when it is already bound to a table (presumably from an 
#	earlier interactive run); only "not defined" / "not a DataFrame" trigger 
#	the load, so Ctrl-C and real errors are no longer swallowed
try:
  print(df_votes_.head())
except (NameError, AttributeError):
  df_votes_ = load_csv(os.path.join(analysis_dir, 
                    '%ith_to_%ith_congress-vote_by_member_procced.csv' 
                    % (congresses_modeled[0], congresses_modeled[-1])), 
                    usecols=usecols)

#load the members table: reuse df_members when it is already bound to a 
#	table, else read the combined CSV, else compile it from the per-congress 
#	files and cache the result as the combined CSV
members_csv = os.path.join(analysis_dir, 
                           '%ith_to_%ith_congress-members-procced.csv' 
                           % (congresses_modeled[0], congresses_modeled[-1]))
try:
  print(df_members.head())
except (NameError, AttributeError):
  #only "not defined" / "not a DataFrame" trigger a (re)load
  try:
    df_members = load_csv(members_csv)
  except FileNotFoundError:
    print('did not find data... compiling...')
    #collect into a scratch list with a scratch loop variable, so the 
    #	module-level `congress` is not clobbered and a failure part-way through 
    #	does not leave df_members bound to a partial list
    members_by_congress = []
    for congress_ in congresses_modeled:
      print('\t%ith congress' % congress_)
      members_by_congress.append(load_csv(os.path.join(analysis_dir, 
                             '%ith_congress-members-procced.csv' % congress_)))
    
    df_members = pd.concat(members_by_congress)
    to_csv(df_members, members_csv)


#work on a copy so the table as loaded (df_votes_) stays untouched
df_votes = df_votes_.copy()

if not merge_third_parties:
  df_votes['legislator_party'] = np.array(df_votes.party)
else:
  #relabel parties: full names -> 'D' / 'R', anything else -> 'I'
  party2label = defaultdict(lambda: 'I', { 'Democrat': 'D', 'Republican': 'R' })
  
  def _sponsor_label(sponsor_party):
    """Keep 'D' and 'R' sponsor codes, fold every other value into 'I'."""
    if sponsor_party in ['D', 'R']:
      return sponsor_party
    return 'I'
  
  df_votes['legislator_party'] = df_votes.party.map(party2label.__getitem__)
  df_votes['sponsor_party'] = df_votes.sponsor_party.map(_sponsor_label)


#initialize a Legislators dataset
Legislators = utils.load(os.path.join(analysis_dir, 'Legislators.lds'))

#keep only rows whose legislator is listed in the Legislators dataset
known_legislators = getattr(Legislators, name_field)
df_votes = df_votes[np.isin(df_votes[name_field], known_legislators)]

#cast_description holds Yea, Nay, Present, and 'Not Voting (Abstention)': 
#	abstentions are always dropped, Present votes only when remove_present
excluded_casts = ['Not Voting (Abstention)']
if remove_present:
  excluded_casts.append('Present')
df_votes = df_votes[~df_votes.cast_description.isin(excluded_casts)]

if only_bills:
  #roll-call questions that are kept; 
  #	motions, cloture, amendments, veto overrides and nominations stay out
  kept_questions = [ 'On Passage',  #House
                     'On Passage of the Bill', #Senate
                     'On Agreeing to the Resolution', #House
                     'On the Resolution', #Senate
                     'On the Joint Resolution', #Senate
                    ]
  df_votes = df_votes[np.isin(df_votes.vote_question, kept_questions)]

#permute the dataset (row order only) using the seeded global RNG
shuffled_rows = np.random.permutation(len(df_votes))
df_votes = df_votes.iloc[shuffled_rows]

#set up pipeline for turning vote records into ordered class labels (int)
#	element-wise lookup; a cast that is not listed raises KeyError
_CAST_TO_ORDINAL = { 'Nay': 0, 'Present': 1, 'Yea': 2 }

def _cast_to_ordinal(vote):
  """Return the ordinal code (0, 1 or 2) of one cast description."""
  return _CAST_TO_ORDINAL[vote]

vote2ordinal__ = np.vectorize(_cast_to_ordinal)

def vote2ordinal_(y):
  """Close the gaps in integer class labels, in place.
  
  The smallest label is kept and every other distinct label is moved down 
  so that the labels in use are consecutive, e.g. [0, 2, 5] -> [0, 1, 2].
  
  y: integer array of class labels; it is modified in place.
  Returns the same array object. Empty input is returned unchanged.
  """
  if y.size == 0:
    return y
  levels = np.unique(y)
  #position of each label among the sorted distinct labels == its new rank; 
  #	computed in one pass, so gaps wider than one level are closed completely
  y[:] = levels[0] + np.searchsorted(levels, y)
  return y

def vote2ordinal(y):
  """Map cast descriptions to ordinal labels with unused levels squeezed out."""
  return vote2ordinal_(vote2ordinal__(y))

#recode votes to Nay=-1, Present=0, Yea=1
#	the uncompacted codes are used on purpose: squeezing out an unused level 
#	first (no 'Present' rows left, e.g. with remove_present) would turn Yea 
#	into 0, the same value the pivot tables below use for missing votes
df_votes['vote'] = vote2ordinal__(df_votes.cast_description) - 1


import re

#a word made only of dotted capitals, e.g. 'S.' or 'J.R.'
#	raw string: '\.' in a plain literal is an invalid escape sequence
_INITIALS_RE = re.compile(r'([A-Z]\.)+')

def proc_name(name):
  """Shorten a legislator name to its first and last word for tick labels.
  
  Words that consist only of dotted capital initials are dropped first; of 
  the remaining words only the first and the last are kept, so middle names 
  and initials disappear. If fewer than two words remain, the name is 
  returned unchanged.
  """
  kept_words = [ word for word in name.split() 
                 if not _INITIALS_RE.fullmatch(word) ]
  if len(kept_words) < 2:
    return name
  return ' '.join([kept_words[0], kept_words[-1]])

#element-wise version for arrays / lists of names
proc_names = np.vectorize(proc_name)

#containers filled by the loop over congresses and chambers below
correlations, dist_dfs = {}, {}

#drop a correlation_stats left over from an earlier run: 
#	the loop below creates it the first time it hits a NameError
try:
  del correlation_stats
except NameError:
  pass

#start from a freshly loaded Legislators dataset
legislators_path = os.path.join(analysis_dir, 'Legislators.lds')
Legislators = utils.load(legislators_path)
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rc
from pingouin import corr, partial_corr


def _embedding(leg_names):
  """Return the embedding vectors of the named legislators, truncated to n_components."""
  return getattr(Legislators[leg_names], embed_field_name)[:,:n_components]


def _party_color(name):
  """Return the tick-label colour for legislator *name*, chosen by party."""
  party_colors = { 'Democrat': 'tab:blue', 'Republican': 'tab:red' }
  #Legislators is looked up at call time, so a reload inside the loop is seen
  return party_colors.get(Legislators[name].party, 'tab:green')


def _style_tick_labels(ax):
  """Shorten the tick labels of *ax* with proc_name and colour them by party."""
  for get_labels, set_labels in ((ax.get_xticklabels, ax.set_xticklabels), 
                                 (ax.get_yticklabels, ax.set_yticklabels)):
    full_names = [ label.get_text() for label in get_labels() ]
    #colours come from the full names, before the labels are shortened
    colors = [ _party_color(name) for name in full_names ]
    set_labels([ proc_name(name) for name in full_names ])
    for label, color in zip(get_labels(), colors):
      label.set_color(color)
      label.set_fontsize(25)


def _cluster_order(rdm):
  """Return the dendrogram leaf order that sns.clustermap computes for *rdm*.
  
  The clustermap figure is only needed for this ordering and is never saved, 
  so it is closed right away instead of being styled and left open.
  """
  grid = sns.clustermap(rdm, metric='cosine', 
                             row_cluster=True, 
                             col_cluster=True, 
                             cbar=False, 
                             cmap='RdBu_r')
  order = grid.dendrogram_col.reordered_ind
  plt.close(grid.ax_heatmap.figure)
  return order


def _save_rdm_pair(vote_rdm, financial_rdm, ordered, tick_int, chamber, path):
  """Save the vote and financial distance matrices as side-by-side heatmaps.
  
  Both matrices must already be arranged in the order given by *ordered* 
  (legislator names). The figure is written to *path* and then closed.
  """
  fig, (ax_votes, ax_fin) = plt.subplots(1,2, figsize=(85,40))
  panels = ((ax_votes, vote_rdm, 'Vote profiles'), 
            (ax_fin, financial_rdm, 'Financial profiles'))
  for ax, rdm, title in panels:
    sns.heatmap(pd.DataFrame(data=rdm, index=ordered, columns=ordered), 
                ax=ax, cmap='RdBu_r', cbar=False, 
                xticklabels=tick_int, yticklabels=tick_int)
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    _style_tick_labels(ax)
    ax.set_title(title, fontsize=55, pad=40)
  
  ax_votes.text(-.15, .5, chamber, rotation=90, ha='center', fontsize=80, 
                verticalalignment='center', transform=ax_votes.transAxes)
  
  fig.subplots_adjust(top=.95, bottom=.125, left=.075, right=.975, wspace=.125)
  fig.savefig(path)
  #close (not just clear) the figure: two of these are made per iteration
  plt.close(fig)


#one block of correlation rows per (congress, chamber), concatenated at the end
rsa_stats = []

for congress in congresses:
  for chamber in ['House', 'Senate']:
    #reload legislators to fix spot checks
    if recode_independents: #reload Legislators with original codes
      Legislators = utils.load(os.path.join(analysis_dir, 'Legislators.lds'))
    
    in_session = np.logical_and(df_votes.chamber == chamber, 
                                df_votes.congress == congress)
    
    #pivot dataset to bill_ids as columns, legislator names as rows
    #	yields the matrix of vote vectors
    #	missing values (Abstentions) coded as 0
    #	values='vote' is explicit: without it pivot_table also tries to 
    #	aggregate string columns, which recent pandas versions warn about 
    #	or refuse
    df_pivot = df_votes[in_session][['bill_id', 'Name_merged', 'vote']]
    df_pivot = df_pivot.pivot_table(values='vote', columns='bill_id', 
                                    index='Name_merged', fill_value=0)
    
    names = df_pivot.index
    vote_vector = np.array(df_pivot)
    distances = pairwise_distances(vote_vector, metric=votes_metric)
    dist_df = pd.DataFrame(data=distances, index=names, columns=names)
    
    #ordering 1: cluster the vote distances
    reordered_ind = _cluster_order(dist_df)
    ordered_names = names[reordered_ind]
    
    #ordering 2: cluster the financial distances (given in ordering 1)
    fin_distances = pairwise_distances(_embedding(ordered_names), 
                                       metric=financials_metric)
    fdist_df = pd.DataFrame(data=fin_distances, index=ordered_names, 
                            columns=ordered_names)
    reordered_ind2 = _cluster_order(fdist_df)
    ordered_names2 = ordered_names[reordered_ind2]
    
    #rsa analysis: correlate the upper triangles of both distance matrices
    tick_int = 1 if chamber == 'Senate' else 5
    
    pair_idx = np.triu_indices(distances.shape[0], 1)
    vote_dists = distances[pair_idx]
    #financials_metric instead of a hard-coded 'cosine', in line with the 
    #	heatmaps above and the permutation test
    fin_dists_ = pairwise_distances(_embedding(names), metric=financials_metric)
    financial_dists = fin_dists_[pair_idx]
    
    corr_stats = corr(vote_dists, financial_dists, method=corr_method)
    corr_stats['congress'] = congress
    corr_stats['chamber'] = chamber
    corr_stats['corr_type'] = 'full_'+corr_method
    
    #keyed by (congress, chamber) so the Senate no longer overwrites the House; 
    #	the plain congress key is kept as an alias of the latest chamber for 
    #	code that still indexes by congress alone
    correlations[congress, chamber] = { 'full': corr_stats }
    correlations[congress] = correlations[congress, chamber]
    
    #heatmap pair sorted by the dendrogram of the VOTE distances
    reordered_dists = distances[reordered_ind][:,reordered_ind]
    _save_rdm_pair(reordered_dists, fin_distances, ordered_names, 
                   tick_int, chamber, 
                   os.path.join(analysis_dir, results_dirname, 'legislators_RDMs-' + \
                   '%ith_congress+state-%s-votesim_sorted.png' % (congress, chamber)))
    
    #heatmap pair sorted by the dendrogram of the FINANCIAL distances
    fin_distances2 = pairwise_distances(_embedding(ordered_names2), 
                                        metric=financials_metric)
    _save_rdm_pair(reordered_dists[reordered_ind2][:,reordered_ind2], 
                   fin_distances2, ordered_names2, tick_int, chamber, 
                   os.path.join(analysis_dir, results_dirname, 'legislators_RDMs-' + \
                   '%ith_congress+state-%s-financialsim_sorted.png' % (congress, chamber)))
    
    #recoding happens after the figures, so tick colours keep the original 
    #	party codes while the party covariate below uses the recoded ones
    if recode_independents:
      Legislators['Bernard Sanders'].party = 'Democrat'
      Legislators['Justin Amash'].party = 'Republican'
      Legislators['Paul Mitchell'].party = 'Republican'
      Legislators['Angus S. King'].party = 'Democrat'
      Legislators['Joseph I. Lieberman'].party = 'Democrat'
    
    #covariates: 1 where a pair differs in party / state, 0 otherwise
    parties = Legislators[names].party
    other_party = (np.expand_dims(parties, 0) != np.expand_dims(parties, 1)).astype(np.int8)
    party_dists = other_party[pair_idx]
    
    states = Legislators[names].state_abbrev
    other_state = (np.expand_dims(states, 0) != np.expand_dims(states, 1)).astype(np.int8)
    state_dists = get_triu(other_state)
    
    n_samples = len(vote_dists)
    dists_df = pd.DataFrame({'congress': [congress]*n_samples, 
                             'chamber': [chamber]*n_samples, 
                             'vote_profile_dist': vote_dists, 
                             'financial_profile_dist': financial_dists, 
                             'party_dist': party_dists, 
                             'state_dist': state_dists })
    
    print(dists_df)
    print(dists_df.shape)
    
    dist_dfs[congress, chamber] = dists_df
    
    #partial correlations: both covariates, state only, party only
    covariate_sets = [ ('partial_'+corr_method, ['party_dist', 'state_dist']), 
                       ('partial_'+corr_method+'_state', ['state_dist']), 
                       ('partial_'+corr_method+'_party', ['party_dist']) ]
    new_stats = [corr_stats]
    for corr_type, covar in covariate_sets:
      stats = partial_corr(data=dists_df, x='vote_profile_dist', 
                                          y='financial_profile_dist', 
                                          covar=covar, 
                                          method=corr_method)
      stats['congress'] = congress
      stats['chamber'] = chamber
      stats['corr_type'] = corr_type
      new_stats.append(stats)
    
    rsa_stats.append(pd.concat(new_stats))
    
    #new_stats[1] is the partial correlation that controls for party and state
    correlations[congress]['partial'] = new_stats[1]

#as before, correlation_stats stays undefined when nothing was analysed
if rsa_stats:
  correlation_stats = pd.concat(rsa_stats)


#write the RSA correlations, keep them in memory under correlation_stats_, 
#	and unbind correlation_stats so the permutation section below does not 
#	start from these rows
rsa_csv_path = os.path.join(analysis_dir, results_dirname, 'RSA_correlations.csv')
correlation_stats.to_csv(rsa_csv_path)
correlation_stats_ = correlation_stats.copy()
del correlation_stats


if run_permutation:
  from statsmodels.stats.multitest import multipletests
  
  def _permuted_triu(matrix):
    """Return the upper triangle of *matrix* after one random relabelling.
    
    Rows and columns are shuffled with the same permutation, drawn from the 
    global NumPy RNG.
    """
    perm = np.random.permutation(matrix.shape[0])
    return get_triu(matrix[perm][:,perm])
  
  #generate a permutation distribution
  n_permutations = 100_000
  
  os.makedirs(os.path.join(analysis_dir, results_dirname), exist_ok=True)
  perm_csv = os.path.join(analysis_dir, results_dirname, 
                          'RSA_correlations+state+permutations.csv')
  
  #resume from the checkpoint file when there is one
  try:
    #index_col=0: the file is written with its index, and reading that back 
    #	as data would add another 'Unnamed: 0' column on every resume
    correlation_stats = pd.read_csv(perm_csv, index_col=0)
  except FileNotFoundError:
    correlation_stats = None
    #NOTE(review): a fresh run starts at the 113th House and therefore skips 
    #	the 111th and 112th congresses; this looks like a leftover manual 
    #	resume point, confirm whether it should be 0
    id_ = congresses.index(113)*2
  else:
    #continue after the last (congress, chamber) found in the checkpoint
    current_congress = correlation_stats.congress.max()
    cong_id = congresses.index(current_congress)
    current_congress_stats = correlation_stats[correlation_stats.congress == current_congress]
    if 'Senate' in current_congress_stats.chamber.unique():
      id_ = cong_id*2 + 2
    else:
      id_ = cong_id*2 + 1
  
  for congress, chamber in list(product(congresses[:], ['House', 'Senate']))[id_:]:
    
    if recode_independents: #reload Legislators with original codes
      Legislators = utils.load(os.path.join(analysis_dir, 'Legislators.lds'))
    
    in_session = np.logical_and(df_votes.chamber == chamber, 
                                df_votes.congress == congress)
    #values='vote' is explicit: without it pivot_table also tries to 
    #	aggregate string columns, which recent pandas versions warn about 
    #	or refuse
    df_pivot = df_votes[in_session][['bill_id', 'Name_merged', 'vote']]
    df_pivot = df_pivot.pivot_table(values='vote', columns='bill_id', 
                                    index='Name_merged', fill_value=0)
    
    #spot-check Independents
    if recode_independents:
      Legislators['Bernard Sanders'].party = 'Democrat'
      Legislators['Justin Amash'].party = 'Republican'
      Legislators['Paul Mitchell'].party = 'Republican'
      Legislators['Angus S. King'].party = 'Democrat'
      Legislators['Joseph I. Lieberman'].party = 'Democrat'
    
    names = df_pivot.index
    vote_vector = np.array(df_pivot)
    vote_distances = pairwise_distances(vote_vector, metric=votes_metric)
    vote_dists = get_triu(vote_distances)
    
    financial_vectors = getattr(Legislators[names], embed_field_name)[:,:n_components]
    financial_distances = pairwise_distances(financial_vectors, metric=financials_metric)
    financial_dists = get_triu(financial_distances)
    
    #covariates: 1 where a pair differs in party / state, 0 otherwise
    parties = Legislators[names].party
    other_party = (np.expand_dims(parties, 0) != np.expand_dims(parties, 1)).astype(np.int8)
    party_dists = get_triu(other_party)
    
    states = Legislators[names].state_abbrev
    other_state = (np.expand_dims(states, 0) != np.expand_dims(states, 1)).astype(np.int8)
    state_dists = get_triu(other_state)
    
    n_samples = len(vote_dists)
    dists_df = pd.DataFrame({'congress': [congress]*n_samples, 
                             'chamber': [chamber]*n_samples, 
                             'vote_profile_dist': vote_dists, 
                             'financial_profile_dist': financial_dists, 
                             'party_dist': party_dists, 
                             'state_dist': state_dists })
    
    #observed statistics
    corr_stats = partial_corr(data=dists_df, x='vote_profile_dist', 
                                             y='financial_profile_dist', 
                                             method=corr_method)
    partial_stats = partial_corr(data=dists_df, x='vote_profile_dist', 
                                                y='financial_profile_dist', 
                                                covar=['party_dist', 'state_dist'], 
                                                method=corr_method)
    new_stats = pd.concat([corr_stats, partial_stats])
    new_stats['congress'] = congress
    new_stats['chamber'] = chamber
    new_stats['stat_type'] = ['full_'+corr_method, 'partial_'+corr_method]
    new_stats['is_perm_sample'] = False
    
    #run permutation stats
    print('permutation test for %ith %s' % (congress, chamber))
    perm_rows = []
    for run in tqdm(range(n_permutations)):
      #each matrix gets its own relabelling; only the columns partial_corr 
      #	reads are built (dict order fixes the order of the RNG draws)
      dists_df_ = pd.DataFrame({
                    'vote_profile_dist': _permuted_triu(vote_distances), 
                    'financial_profile_dist': _permuted_triu(financial_distances), 
                    'party_dist': _permuted_triu(other_party), 
                    'state_dist': _permuted_triu(other_state) })
      
      #method=corr_method here as well: without it pingouin uses its default 
      #	method, so the null for 'full_'+corr_method would not match the 
      #	observed statistic
      corr_stats_ = partial_corr(data=dists_df_, x='vote_profile_dist', 
                                                 y='financial_profile_dist', 
                                                 method=corr_method)
      partial_stats_ = partial_corr(data=dists_df_, x='vote_profile_dist', 
                                                    y='financial_profile_dist', 
                                                    covar=['party_dist', 'state_dist'], 
                                                    method=corr_method)
      new_stats_ = pd.concat([corr_stats_, partial_stats_])
      new_stats_['congress'] = congress
      new_stats_['chamber'] = chamber
      new_stats_['stat_type'] = ['full_'+corr_method, 'partial_'+corr_method]
      new_stats_['is_perm_sample'] = True
      perm_rows.append(new_stats_)
    
    #append the observed rows, then the permutation rows, and checkpoint
    previous = [] if correlation_stats is None else [correlation_stats]
    correlation_stats = pd.concat(previous + [new_stats] + perm_rows)
    correlation_stats.to_csv(perm_csv)
  
  
  #copy: the p-value columns are added to true_stats, not to a slice
  true_stats = correlation_stats[~correlation_stats.is_perm_sample].copy()
  perm_stats = correlation_stats[correlation_stats.is_perm_sample]
  
  #compute permutation test results
  p_vals = []
  observed = true_stats[['congress', 'chamber', 'stat_type', 'r']]
  for congress, chamber, stat_type, r in observed.itertuples(index=False, name=None):
    print(congress, chamber, stat_type, r)
    same_test = ((perm_stats.congress == congress) 
                 & (perm_stats.chamber == chamber) 
                 & (perm_stats.stat_type == stat_type))
    perm_r = np.array(perm_stats[same_test].r)
    #two-tailed: share of permuted |r| that reach the observed |r|
    p_2tailed = (np.abs(r) <= np.abs(perm_r)).sum()/perm_r.shape[0]
    p_vals.append(p_2tailed)
    print('p_2tailed', p_2tailed)
  
  true_stats['p_permutation'] = p_vals
  reject, p_bonf, _, alpha_corrected = multipletests(p_vals, method='bonferroni')
  true_stats['p_bonferroni'] = p_bonf
  print(true_stats.groupby(['chamber', 'stat_type', 'congress']).sum())
  
  true_stats.to_csv(os.path.join(analysis_dir, 
                    results_dirname, 'rsa_correlations+perm_corrected.csv'))




